from ready_data.cifar10 import Cifar10, CIFAR10_CLASSES
from ready_data.base import load_first_batch
from utils.plot import plot_nine_images


def run_prepare():
    """Instantiate the CIFAR-10 wrapper and invoke its ``prepare_data`` step."""
    Cifar10().prepare_data()


def run_load_batch():
    """Print the size and value range of one channel of the first training image.

    Prepares the dataset, pulls the first batch from ``train`` via
    ``load_first_batch`` and reports ``size()``, ``max()`` and ``min()`` of
    ``images[0][0]`` (index 0, labelled "channel1" in the output).
    Presumably ``images`` is a torch tensor shaped (N, C, H, W) -- TODO confirm.
    """
    dataset = Cifar10()
    dataset.prepare_data()
    batch_images, _labels = load_first_batch(dataset.train)
    first_channel = batch_images[0][0]
    # Gather the stats first (same call order as before), then format them.
    dims, top, bottom = first_channel.size(), first_channel.max(), first_channel.min()
    print(f'channel1 size: {dims}, max: {top}, min: {bottom}')


def run_divide():
    """Run only the download and divide steps of the CIFAR-10 wrapper.

    Calls the private ``_download`` and ``_divide`` helpers directly instead of
    going through ``prepare_data``, so the divide step can be exercised on its own.
    NOTE(review): presumably ``_divide`` relies on the downloaded data, hence
    the ordering -- confirm against ``Cifar10``.
    """
    dataset = Cifar10()
    dataset._download()
    dataset._divide()


def run_plot_batch():
    """Plot images from the first batch of the full dataset.

    Prepares the dataset, takes the first batch of ``full`` via
    ``load_first_batch`` and hands images, labels and ``CIFAR10_CLASSES``
    to ``plot_nine_images``.
    """
    dataset = Cifar10()
    dataset.prepare_data()
    batch_images, batch_labels = load_first_batch(dataset.full)
    plot_nine_images(batch_images, batch_labels, CIFAR10_CLASSES)


if __name__ == '__main__':
    import argparse

    # Map CLI task names to the runner functions defined in this file so any of
    # them can be invoked without editing code; plotting remains the default,
    # so running the script with no arguments behaves exactly as before.
    _RUNNERS = {
        'prepare': run_prepare,
        'load_batch': run_load_batch,
        'divide': run_divide,
        'plot_batch': run_plot_batch,
    }
    _parser = argparse.ArgumentParser(description='Run a CIFAR-10 data utility.')
    _parser.add_argument(
        'task',
        nargs='?',
        default='plot_batch',
        choices=sorted(_RUNNERS),
        help='which runner to execute (default: %(default)s)',
    )
    _RUNNERS[_parser.parse_args().task]()